| """ |
| 快速训练脚本 - 用于测试和调试 |
| 只使用数据集的前100个样本进行快速多 epoch 测试 |
| """ |
| import argparse |
| import os, sys |
| import math |
| BASE_DIR = os.path.dirname(os.path.abspath(__file__)) |
| sys.path.append(BASE_DIR) |
|
|
| import pprint |
| import time |
| import torch |
| import torch.nn.parallel |
| from torch.cuda import amp |
| import torch.backends.cudnn as cudnn |
| import torch.optim |
| import torch.utils.data |
| import torchvision.transforms as transforms |
| import numpy as np |
| from tensorboardX import SummaryWriter |
|
|
| import lib.dataset as dataset |
| from lib.config import cfg |
| from lib.config import update_config |
| from lib.core.loss import get_loss |
| from lib.core.function import train |
| from lib.core.function import validate |
| from lib.core.general import fitness |
| from lib.models import get_net |
| from lib.utils.utils import get_optimizer |
| from lib.utils.utils import save_checkpoint |
| from lib.utils.utils import create_logger, select_device |
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the quick-train script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is what callers passing no argument get.

    Returns:
        argparse.Namespace with the attributes ``config``, ``samples``,
        ``epochs``, ``batch_size``, ``yolo_scale``, ``freeze_backbone`` and
        ``workers``.
    """
    parser = argparse.ArgumentParser(description='Quick train for testing')

    parser.add_argument('--config', type=str, default='yolov11',
                        help='config to use: default or yolov11')
    parser.add_argument('--samples', type=int, default=100,
                        help='number of samples to use for quick test')
    parser.add_argument('--epochs', type=int, default=10,
                        help='number of epochs for quick test')
    parser.add_argument('--batch-size', type=int, default=4,
                        help='batch size for quick test')
    parser.add_argument('--yolo-scale', type=str, default='s',
                        choices=['n', 's', 'm', 'l', 'x'],
                        help='YOLOv11 scale (only used if config=yolov11)')
    parser.add_argument('--freeze-backbone', action='store_true',
                        help='freeze YOLOv11 backbone')
    parser.add_argument('--workers', type=int, default=0,
                        help='number of data loading workers')

    # Forwarding argv lets tests and other scripts drive the parser without
    # touching the global sys.argv.
    return parser.parse_args(argv)
|
|
|
|
class SubsetDataset(torch.utils.data.Dataset):
    """Expose only the first ``num_samples`` items of a wrapped dataset.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.
        num_samples: Requested number of leading items. It is clamped to the
            range ``[0, len(dataset)]``.
    """

    def __init__(self, dataset, num_samples):
        self.dataset = dataset
        # Clamp on both sides: never more items than the dataset holds, and
        # never a negative length (len() would raise ValueError on it).
        self.num_samples = max(0, min(num_samples, len(dataset)))

    def __len__(self):
        """Return the number of items visible through this subset."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return item ``idx`` of the subset.

        Negative indices count from the end of the subset, not from the end
        of the wrapped dataset.

        Raises:
            IndexError: If ``idx`` falls outside the subset.
        """
        position = idx + self.num_samples if idx < 0 else idx
        if not 0 <= position < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for a subset of "
                f"{self.num_samples} samples"
            )
        return self.dataset[position]
|
|
|
|
def _build_subset(cfg, is_train, num_samples):
    """Instantiate the configured dataset split and keep its first items.

    Args:
        cfg: Configuration node providing ``DATASET.DATASET`` (name of a class
            exported by ``lib.dataset``) and ``MODEL.IMAGE_SIZE``.
        is_train: Forwarded to the dataset constructor as ``is_train``.
        num_samples: Upper bound on the number of samples to keep.

    Returns:
        SubsetDataset wrapping the freshly built dataset.
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    # Plain attribute lookup instead of eval(): same result for a class name,
    # without executing arbitrary text taken from the configuration.
    dataset_cls = getattr(dataset, cfg.DATASET.DATASET)
    full_dataset = dataset_cls(
        cfg=cfg,
        is_train=is_train,
        inputsize=cfg.MODEL.IMAGE_SIZE,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )
    return SubsetDataset(full_dataset, num_samples)


def _build_loader(cfg, subset, batch_size, shuffle):
    """Wrap ``subset`` in a DataLoader using the project's collate function."""
    return torch.utils.data.DataLoader(
        subset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY,
        collate_fn=dataset.AutoDriveDataset.collate_fn
    )


def _run_quick_train(args, cfg, logger, final_output_dir, tb_log_dir, writer_dict):
    """Build model and data, then run the train/validate/checkpoint loop.

    Args:
        args: Parsed command-line namespace (``samples`` is read here).
        cfg: Configuration node already patched with the quick-train values.
        logger: Logger used for all progress messages.
        final_output_dir: Directory that receives the checkpoints.
        tb_log_dir: TensorBoard log directory, forwarded to ``validate``.
        writer_dict: Dict holding the SummaryWriter and the global step
            counters, forwarded to ``train`` and ``validate``.
    """
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    logger.info("Building model...")
    device = select_device(logger, batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU)

    if hasattr(cfg.MODEL, 'USE_YOLOV11') and cfg.MODEL.USE_YOLOV11:
        model = get_net(
            cfg,
            yolo_scale=cfg.MODEL.YOLOV11_SCALE,
            yolo_weights_path=cfg.MODEL.YOLOV11_WEIGHTS,
            freeze_backbone=cfg.MODEL.FREEZE_BACKBONE
        ).to(device)
    else:
        model = get_net(cfg).to(device)

    logger.info("Model created successfully")

    # Debug dump of the detection head module.
    print("++++++++++++++++++++++")
    print(model.model[model.detector_index])

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")
    logger.info(f"Frozen parameters: {total_params - trainable_params:,}")

    criterion = get_loss(cfg, device=device)
    optimizer = get_optimizer(cfg, model)

    # Guard the divisor so that a run with zero epochs does not fail with
    # ZeroDivisionError when LambdaLR evaluates the factor at construction.
    total_epochs = max(cfg.TRAIN.END_EPOCH, 1)

    def lr_factor(epoch_idx):
        """Cosine decay of the LR multiplier from 1 down to cfg.TRAIN.LRF."""
        cosine = (1 + math.cos(epoch_idx * math.pi / total_epochs)) / 2
        return cosine * (1 - cfg.TRAIN.LRF) + cfg.TRAIN.LRF

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)

    logger.info("Loading dataset...")
    train_dataset = _build_subset(cfg, True, args.samples)
    logger.info(f"Using {len(train_dataset)} training samples")
    train_loader = _build_loader(
        cfg, train_dataset, cfg.TRAIN.BATCH_SIZE_PER_GPU, shuffle=True
    )

    # Validation uses half as many samples as training.
    valid_dataset = _build_subset(cfg, False, args.samples // 2)
    logger.info(f"Using {len(valid_dataset)} validation samples")
    valid_loader = _build_loader(
        cfg, valid_dataset, cfg.TEST.BATCH_SIZE_PER_GPU, shuffle=False
    )

    # Gradient scaling is only enabled when not running on the CPU.
    scaler = amp.GradScaler(enabled=device.type != 'cpu')

    logger.info("Starting training...")
    logger.info("="*80)

    # Kept as a plain float for the whole run so that formatting it works
    # even when no validation pass ever improves on the initial value.
    best_fitness = 0.0
    num_batch = len(train_loader)
    num_warmup = max(round(cfg.TRAIN.WARMUP_EPOCHS * num_batch), 1000)
    last_epoch = cfg.TRAIN.END_EPOCH - 1

    for epoch in range(cfg.TRAIN.BEGIN_EPOCH, cfg.TRAIN.END_EPOCH):
        logger.info(f"\n{'='*80}")
        logger.info(f"Epoch {epoch}/{cfg.TRAIN.END_EPOCH-1}")
        logger.info(f"{'='*80}")

        train(
            cfg, train_loader, model, criterion, optimizer,
            scaler, epoch, num_batch, num_warmup,
            writer_dict, logger, device
        )

        lr_scheduler.step()

        if epoch % cfg.TRAIN.VAL_FREQ == 0 or epoch == last_epoch:
            logger.info("\nValidating...")
            (da_segment_results, ll_segment_results, detect_results,
             total_loss, maps, times) = validate(
                epoch, cfg, valid_loader, valid_dataset, model, criterion,
                final_output_dir, tb_log_dir, writer_dict, logger, device
            )

            fi = fitness(np.array(detect_results).reshape(1, -1))
            current_fitness = float(fi.item())
            logger.info(f"Fitness: {current_fitness:.4f}")

            if current_fitness > best_fitness:
                best_fitness = current_fitness
                logger.info(f"New best fitness: {best_fitness:.4f}")
                save_checkpoint(
                    epoch=epoch + 1,
                    # Use the configured model name rather than a placeholder
                    # so both checkpoint kinds carry the same identifier.
                    name=cfg.MODEL.NAME,
                    model=model,
                    optimizer=optimizer,
                    output_dir=final_output_dir,
                    filename='checkpoint_best.pth',
                    is_best=True
                )

        # NOTE(review): the per-epoch checkpoint is assumed to run on every
        # epoch, not only on validation epochs -- confirm the intended nesting.
        save_checkpoint(
            epoch=epoch,
            name=cfg.MODEL.NAME,
            model=model,
            optimizer=optimizer,
            output_dir=final_output_dir,
            filename=f'epoch-{epoch}.pth'
        )

    logger.info("\n" + "="*80)
    logger.info("Training completed!")
    logger.info(f"Best fitness: {best_fitness:.4f}")
    logger.info(f"Results saved to: {final_output_dir}")
    logger.info("="*80)


def main():
    """Run a short training session on a small slice of the dataset.

    Selects the configuration named by ``--config``, overrides the epoch
    count, batch size, worker count and print frequency with quick-test
    values, then trains and validates on the first ``--samples`` training
    items and the first ``--samples // 2`` validation items.
    """
    args = parse_args()

    # The import is local because the configuration module depends on the
    # requested mode; it deliberately shadows the module-level ``cfg``.
    if args.config == 'yolov11':
        from lib.config.yolov11 import cfg

        cfg.MODEL.YOLOV11_SCALE = args.yolo_scale
        cfg.MODEL.YOLOV11_WEIGHTS = f'weights/yolo11{args.yolo_scale}.pt'
        cfg.MODEL.FREEZE_BACKBONE = args.freeze_backbone
    else:
        from lib.config.default import _C as cfg

    # Quick-test overrides.
    cfg.TRAIN.BEGIN_EPOCH = 0
    cfg.TRAIN.END_EPOCH = args.epochs
    cfg.TRAIN.BATCH_SIZE_PER_GPU = args.batch_size
    cfg.WORKERS = args.workers
    cfg.PRINT_FREQ = 5

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, cfg.LOG_DIR, 'quick_train'
    )

    logger.info("="*80)
    logger.info("QUICK TRAIN MODE - Testing Configuration")
    logger.info("="*80)
    logger.info(f"Config: {args.config}")
    logger.info(f"Samples: {args.samples}")
    logger.info(f"Epochs: {args.epochs}")
    logger.info(f"Batch size: {args.batch_size}")
    if args.config == 'yolov11':
        logger.info(f"YOLOv11 scale: {args.yolo_scale}")
        logger.info(f"Freeze backbone: {args.freeze_backbone}")
    logger.info("="*80)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # Close the TensorBoard writer even when training aborts with an error.
    try:
        _run_quick_train(
            args, cfg, logger, final_output_dir, tb_log_dir, writer_dict
        )
    finally:
        writer_dict['writer'].close()
|
|
|
|
# Start the quick-train run only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|