| | from __future__ import print_function |
| | import torch.backends.cudnn as cudnn |
| | import torch |
| | import torchvision.transforms as transforms |
| |
|
| | import argparse |
| | import os |
| | import random |
| | import sys |
| | import pprint |
| | import datetime |
| | import dateutil |
| | import dateutil.tz |
| |
|
| |
|
| | dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.'))) |
| | sys.path.append(dir_path) |
| |
|
| | from miscc.datasets import TextDataset |
| | from miscc.config import cfg, cfg_from_file |
| | from miscc.utils import mkdir_p |
| | from trainer import GANTrainer |
| |
|
| |
|
| | def parse_args(): |
| | parser = argparse.ArgumentParser(description='Train a GAN network') |
| | parser.add_argument('--cfg', dest='cfg_file', |
| | help='optional config file', |
| | default='birds_stage1.yml', type=str) |
| | parser.add_argument('--gpu', dest='gpu_id', type=str, default='0') |
| | parser.add_argument('--data_dir', dest='data_dir', type=str, default='') |
| | parser.add_argument('--manualSeed', type=int, help='manual seed') |
| | args = parser.parse_args() |
| | return args |
| |
|
| | if __name__ == "__main__": |
| | args = parse_args() |
| | if args.cfg_file is not None: |
| | cfg_from_file(args.cfg_file) |
| | if args.gpu_id != -1: |
| | cfg.GPU_ID = args.gpu_id |
| | if args.data_dir != '': |
| | cfg.DATA_DIR = args.data_dir |
| | print('Using config:') |
| | pprint.pprint(cfg) |
| | if args.manualSeed is None: |
| | args.manualSeed = random.randint(1, 10000) |
| | random.seed(args.manualSeed) |
| | torch.manual_seed(args.manualSeed) |
| | if cfg.CUDA: |
| | torch.cuda.manual_seed_all(args.manualSeed) |
| | now = datetime.datetime.now(dateutil.tz.tzlocal()) |
| | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') |
| | output_dir = '../output/%s_%s_%s' % \ |
| | (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp) |
| |
|
| | num_gpu = len(cfg.GPU_ID.split(',')) |
| | if cfg.TRAIN.FLAG: |
| | dataset = TextDataset(cfg.DATA_DIR, 'train', |
| | rirsize=cfg.RIRSIZE) |
| | assert dataset |
| | |
| | dataloader = torch.utils.data.DataLoader( |
| | dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu, |
| | drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS)) |
| |
|
| | algo = GANTrainer(output_dir) |
| | algo.train(dataloader, cfg.STAGE) |
| | else: |
| | file_path = cfg.EVAL_DIR |
| | algo = GANTrainer(output_dir) |
| | algo.sample(file_path, cfg.STAGE) |
| |
|